import ctypes
from math import sqrt, cos, sin, acos

import ephere_moov as moov
import ephere_ornatrix as ox

class Vector3f( ctypes.Structure ):
	"""Single-precision 3D vector stored in a ctypes ``float[3]`` field named ``data``.

	Implements both the Python 2 and Python 3 operator/iterator protocols.
	"""
	_fields_ = ( "data", ctypes.c_float * 3 ),

	def __init__( self, *args ):
		"""Fills the vector from the given initializers.

		Vector3f()            -> ( 0, 0, 0 )
		Vector3f( source )    -> copy of a Vector3f, list, tuple, moov.Vector3 or ox.Vector3
		Vector3f( x, y )      -> ( x, y, 0 )
		Vector3f( x, y, z )

		Raises:
			ValueError: more than three initializers were given.
			TypeError: a single initializer of an unsupported type was given.
		"""
		if len( args ) > 3:
			raise ValueError( 'Too many initializers' )
		if len( args ) == 1:
			source = args[0]
			if isinstance( source, Vector3f ):
				# Assigning a ctypes array field copies the values rather than aliasing them.
				self.data = source.data
			elif isinstance( source, ( list, tuple ) ) or isinstance( source, ( moov.Vector3, ox.Vector3 ) ):
				# Plain Python sequences are tested first so they never touch the extension types.
				self.data[0] = source[0]
				self.data[1] = source[1]
				self.data[2] = source[2]
			else:
				raise TypeError( 'Invalid argument' )
		else:
			# Zero, two or three scalar components; any not given keep ctypes' zero initialization.
			for index, value in enumerate( args ):
				self.data[index] = value

	def getx( self ):
		return self.data[0]

	def gety( self ):
		return self.data[1]

	def getz( self ):
		return self.data[2]

	x = property( getx )
	y = property( gety )
	z = property( getz )

	def __getitem__( self, index ):
		return self.data[index]

	def lengthSquared( self ):
		"""Returns the squared Euclidean length (avoids the square root)."""
		return self.data[0] * self.data[0] + self.data[1] * self.data[1] + self.data[2] * self.data[2]

	def length( self ):
		"""Returns the Euclidean length."""
		return sqrt( self.lengthSquared() )

	def dot( self, other ):
		"""Returns the dot product with ``other`` (anything exposing x, y, z)."""
		return self.x * other.x + self.y * other.y + self.z * other.z

	def cross( self, other ):
		"""Returns the cross product ``self x other`` as a new Vector3f."""
		return Vector3f( self.y * other.z - self.z * other.y, self.z * other.x - self.x * other.z, self.x * other.y - self.y * other.x )

	def normalize( self ):
		"""Returns a unit-length copy; raises ZeroDivisionError for the zero vector."""
		return self / self.length()

	def orthogonal( self ):
		"""Returns any orthogonal vector"""
		absx = abs( self.x )
		absy = abs( self.y )
		absz = abs( self.z )
		# Cross with the axis of the smallest component; that axis is the least parallel one.
		other = ( Vector3f( 1, 0, 0 ) if absx < absz else Vector3f( 0, 0, 1 ) ) if absx < absy else ( Vector3f( 0, 1, 0 ) if absy < absz else Vector3f( 0, 0, 1 ) )
		return self.cross( other )

	def __repr__( self ):
		return 'Vector3f( %f, %f, %f )' % ( self.data[0], self.data[1], self.data[2] )

	def __len__( self ):
		return 3

	def __iter__( self ):
		return Vector3f.Iterator( self )

	def __eq__( self, other ):
		return self.x == other.x and self.y == other.y and self.z == other.z

	def __ne__( self, other ):
		return not( self == other )

	def __add__( self, other ):
		return Vector3f( self.x + other.x, self.y + other.y, self.z + other.z )

	def __sub__( self, other ):
		return Vector3f( self.x - other.x, self.y - other.y, self.z - other.z )

	def __mul__( self, other ):
		"""Scales the vector by the scalar ``other``."""
		return Vector3f( self.x * other, self.y * other, self.z * other )

	# Lets ``scalar * vector`` work as well as ``vector * scalar``.
	__rmul__ = __mul__

	def __div__( self, other ):
		"""Divides every component by the scalar ``other``."""
		return Vector3f( self.x / other, self.y / other, self.z / other )

	# Python 3 routes ``/`` to __truediv__ only; without this, normalize() fails there.
	__truediv__ = __div__

	class Iterator( object ):
		"""Iterates over the three components of a Vector3f."""

		def __init__( self, target ):
			self.target = target
			self.current = -1

		def __iter__( self ):
			return self

		def next( self ):
			if self.current == 2:
				raise StopIteration()
			self.current = self.current + 1
			return self.target.data[self.current]

		# Python 3 spelling of the iterator protocol method.
		__next__ = next

# OpenGL column-major
class Matrix4f( ctypes.Structure ):
	"""4x4 single-precision matrix stored column-major.

	``data[column][row]`` addresses one element; the translation part lives in ``data[3][0..2]``.
	"""
	_fields_ = ( "data", ( ctypes.c_float * 4 ) * 4 ),

	def __init__( self, *args ):
		"""Initializes the matrix.

		Matrix4f()           -> identity
		Matrix4f( other )    -> copy of another Matrix4f
		Matrix4f( ndarray )  -> copy of a numpy array's raw buffer
		Matrix4f( data )     -> assigned directly to the ctypes ``data`` field

		Raises:
			ValueError: more than one initializer was given.
		"""
		if len( args ) > 1:
			raise ValueError( 'Too many initializers' )
		if not args:
			self.data[0][0] = self.data[1][1] = self.data[2][2] = self.data[3][3] = 1
			return
		source = args[0]
		if isinstance( source, Matrix4f ):
			self.data = source.data
		elif type( source ).__name__ == 'ndarray':
			# NOTE(review): this reinterprets the array's raw buffer, so it presumably must be a
			# C-contiguous float32 array of 16 values laid out as [column][row] — confirm with callers.
			self.data = source.ctypes.data_as( ctypes.POINTER( ( ctypes.c_float * 4 ) * 4 ) ).contents
		else:
			self.data = source

	def __repr__( self ):
		return 'Matrix4f( [ %f, %f, %f, %f ], [ %f, %f, %f, %f ], [ %f, %f, %f, %f ], [ %f, %f, %f, %f ] )' \
			% ( self.data[0][0], self.data[1][0], self.data[2][0], self.data[3][0], self.data[0][1], self.data[1][1], self.data[2][1], self.data[3][1],
			self.data[0][2], self.data[1][2], self.data[2][2], self.data[3][2], self.data[0][3], self.data[1][3], self.data[2][3], self.data[3][3] )

	def __mul__( self, other ):
		"""Multiplies by a Matrix4f or transforms a Vector3f.

		Matrix4f: returns the product ``self * other`` (``other`` is applied first to points).
		Vector3f: the vector is treated as a point (w = 1), so translation is applied; the
		bottom row is ignored (no perspective divide).

		Raises:
			TypeError: ``other`` is neither a Matrix4f nor a Vector3f.
		"""
		if isinstance( other, Matrix4f ):
			result = Matrix4f()
			for column in range( 4 ):
				for row in range( 4 ):
					# Element (row, k) of a column-major matrix is data[k][row].
					result.data[column][row] = sum( self.data[k][row] * other.data[column][k] for k in range( 4 ) )
			return result
		if isinstance( other, Vector3f ):
			return Vector3f(
				self.data[0][0] * other.x + self.data[1][0] * other.y + self.data[2][0] * other.z + self.data[3][0],
				self.data[0][1] * other.x + self.data[1][1] * other.y + self.data[2][1] * other.z + self.data[3][1],
				self.data[0][2] * other.x + self.data[1][2] * other.y + self.data[2][2] * other.z + self.data[3][2] )
		raise TypeError( 'Unsupported type for matrix multiplication' )

	def GetRow( self, index ):
		"""Returns the first three entries (columns 0-2) of row ``index`` as a Vector3f."""
		return Vector3f( self.data[0][index], self.data[1][index], self.data[2][index] )

	@staticmethod
	def Scale( factor ):
		"""Returns a uniform scale matrix."""
		m = Matrix4f()
		m.data[0][0] = m.data[1][1] = m.data[2][2] = factor
		return m

	@staticmethod
	def XRotation( angle ):
		"""Returns a rotation of ``angle`` radians about the X axis; positive angles turn +Y toward +Z."""
		m = Matrix4f()
		c = cos( angle )
		s = sin( angle )
		m.data[1][1] = m.data[2][2] = c
		m.data[1][2] = s
		m.data[2][1] = -s
		return m

def isclose( a, b, rel_tol = 1e-09, abs_tol = 0.0 ):
	"""Returns True when ``a`` and ``b`` differ by no more than the larger of the
	relative tolerance (scaled by the larger magnitude) and the absolute tolerance.

	Same formula as math.isclose, without its argument validation.
	"""
	difference = abs( a - b )
	largestMagnitude = max( abs( a ), abs( b ) )
	allowed = max( rel_tol * largestMagnitude, abs_tol )
	return difference <= allowed

class Quaternion( object ):
	"""Quaternion ``w + xi + yj + zk`` with plain Python number components."""

	def __init__( self, *args ):
		"""Initializes the quaternion.

		Quaternion()               -> identity ( 1, 0, 0, 0 )
		Quaternion( q )            -> copy of a Quaternion or a moov.Quaternion
		Quaternion( w, v )         -> scalar part w, vector part v[0], v[1], v[2]
		Quaternion( w, x, y, z )

		Raises:
			ValueError: any other argument combination.
		"""
		if len( args ) == 0:
			self.w = 1
			self.x = self.y = self.z = 0
		elif len( args ) == 1 and isinstance( args[0], Quaternion ):
			other = args[0]
			self.w, self.x, self.y, self.z = other.w, other.x, other.y, other.z
		elif len( args ) == 1 and type( args[0] ) == moov.Quaternion:
			self.w = args[0].real
			self.x = args[0].imag[0]
			self.y = args[0].imag[1]
			self.z = args[0].imag[2]
		elif len( args ) == 2:
			self.w = args[0]
			self.x = args[1][0]
			self.y = args[1][1]
			self.z = args[1][2]
		elif len( args ) == 4:
			self.w = args[0]
			self.x = args[1]
			self.y = args[2]
			self.z = args[3]
		else:
			raise ValueError( 'Invalid arguments' )

	@staticmethod
	def _Normalized( vector ):
		"""Returns ``vector`` scaled to unit length (multiplication, so Vector3f division is not required)."""
		return vector * ( 1.0 / vector.length() )

	@staticmethod
	def FromAngleAxis( angle, axis ):
		"""Returns the rotation of ``angle`` radians about ``axis``.

		``axis`` may be anything Vector3f accepts as a single initializer and need not be unit length.
		"""
		# Multiply rather than divide so integer angles are not floor-divided under Python 2.
		halfAngle = angle * 0.5
		return Quaternion( cos( halfAngle ), Quaternion._Normalized( Vector3f( axis ) ) * sin( halfAngle ) )

	@staticmethod
	def ToAngleAxis( q ):
		"""Returns ( angle, axis ) for ``q``: angle in radians within [0, 2*pi], axis a unit Vector3f.

		``q`` need not be normalized. A near-zero quaternion yields angle 0; a quaternion with
		no vector part yields axis ( 0, 0, 1 ).
		"""
		normSquared = q.w * q.w + q.x * q.x + q.y * q.y + q.z * q.z
		if normSquared < 1e-6:
			angle = 0
		else:
			# Rounding can push |w| / |q| slightly past 1, which acos rejects with ValueError.
			cosHalfAngle = max( -1.0, min( 1.0, q.w / sqrt( normSquared ) ) )
			angle = acos( cosHalfAngle ) * 2
		axis = Vector3f( q.x, q.y, q.z )
		axis = Vector3f( 0, 0, 1 ) if axis.lengthSquared() == 0 else Quaternion._Normalized( axis )
		return angle, axis

	@staticmethod
	def FromTwoVectors( v1, v2 ):
		"""Returns the shortest-arc rotation taking the direction of ``v1`` onto that of ``v2``.

		Both arguments are Vector3f. Raises ZeroDivisionError if either vector is zero.
		"""
		crossProduct = v1.cross( v2 )
		dotProduct = v1.dot( v2 )
		k = sqrt( v1.lengthSquared() * v2.lengthSquared() )

		# Cosine of the angle between the vectors; at +/-1 the cross product is degenerate.
		collinearityTest = dotProduct / k
		if isclose( collinearityTest, -1.0, abs_tol = 1e-7 ):
			# Opposite directions: half a turn about any perpendicular axis.
			return Quaternion( 0, Quaternion._Normalized( v1.orthogonal() ) )
		if isclose( collinearityTest, 1.0, abs_tol = 1e-7 ):
			return Quaternion( 1, 0, 0, 0 )

		# ( |v1||v2| + v1.v2, v1 x v2 ) is a scaled half-angle quaternion; normalize it.
		k += dotProduct
		length = sqrt( k * k + crossProduct.lengthSquared() )
		return Quaternion( k / length, crossProduct * ( 1.0 / length ) )

	def __repr__( self ):
		return 'Quaternion( %f, %f, %f, %f )' % ( self.w, self.x, self.y, self.z )

	def __mul__( self, other ):
		"""Returns the Hamilton product ``self * other``."""
		return Quaternion(
			self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z,
			self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y,
			self.w * other.y + self.y * other.w + self.z * other.x - self.x * other.z,
			self.w * other.z + self.z * other.w + self.x * other.y - self.y * other.x )

	def ToMatrix( self ):
		"""Returns the rotation as a column-major Matrix4f with no translation.

		Exact for a unit quaternion; otherwise the 3x3 part is scaled by |q|^2.
		"""
		result = Matrix4f()
		result.data[0][0] = self.w * self.w + self.x * self.x - self.y * self.y - self.z * self.z
		result.data[1][0] = 2 * self.x * self.y - 2 * self.w * self.z
		result.data[2][0] = 2 * self.x * self.z + 2 * self.w * self.y
		result.data[0][1] = 2 * self.x * self.y + 2 * self.w * self.z
		result.data[1][1] = self.w * self.w - self.x * self.x + self.y * self.y - self.z * self.z
		result.data[2][1] = 2 * self.y * self.z - 2 * self.w * self.x
		result.data[0][2] = 2 * self.x * self.z - 2 * self.w * self.y
		result.data[1][2] = 2 * self.y * self.z + 2 * self.w * self.x
		result.data[2][2] = self.w * self.w - self.x * self.x - self.y * self.y + self.z * self.z
		return result
